function VAST_preproc_ERP(filename,STARTSTEP,loadMusc)
% VAST_PREPROC_ERP  Preprocess VAST data for ERP analysis.
%
% Runs the pipeline below from STARTSTEP onward, saving the intermediate
% data and accessory variables of every step to settings.filedirAnalyzed.
%
% INPUTS ---
% filename: of the form 'p0246_a.bdf' (default: 'p0246_a.bdf')
% STARTSTEP: where do you want to start the analysis (see below).
%            Default 0; negative values are treated as 0; values > 8 error.
% loadMusc: look for manually detected artifacts from ERP pipeline?
%           Must be 0 or 1 (default 0). NOTE: loadMusc == 1 is not
%           implemented yet and raises an error when Step 2 is reached.
%
% Guide for STARTSTEP: At what preprocessing stage are you starting?
% Note: If STARTSTEP > 0, necessary preceding variables will be loaded by 
% this pipeline's load_analyzed_vars function, load_analyzed_vars_VAST_ERP.
%
%
% Step 0: Read data channel-by-channel, immediately downsample
% Step 1: Filter between 0.5 and 20 Hz
% Step 2: Manually identify muscle artifacts (if loadMusc == 0), or
%         load artifact segments from previous TFR analysis
% Step 3: ICA to remove blinks/saccades
% Step 4: Pull relevant ERP triggers, save ERP definition cell array
% Step 5: Epoch around ERPs
% Step 6: Peak-to-peak artifact rejection
% Step 7: Epoch-level baseline correction
% Step 8: Calculate grand average ERPs


%% User defined variables

settings.do_ica = 1; % run ICA for blinks, or skip it
settings.p2pThresh = 100; % uV, maximum allowable excursion in any channel 
settings.ds2 = 256; % Hz, data will be downsampled to this SR
settings.origFs = 2048;
settings.baseDur = 0.1; % s
settings.BLmethod = 'epoch'; % baseline correct the 100ms preceding each
                             % epoch, or each window (ENC, INT, RET)?
if ~ismember(settings.BLmethod, {'epoch','window'})
    error('invalid baseline correction method')
end

% Check for necessary inputs
if nargin < 3
    loadMusc = 0; % define artifacts manually by default
end
if ~ismember(loadMusc,[0 1])
    error('loadMusc must be either 0 or 1')
end

if nargin < 2
    STARTSTEP = 0;
end
if STARTSTEP < 0 
    STARTSTEP = 0;
elseif STARTSTEP > 8
    error('Specified STARTSTEP exceeds the number of analysis steps')
end

if nargin < 1
    filename = 'p0246_a.bdf';
end


% Set Paths
% NOTE: New raw data paths imply that data are in pXXXX folders
compname = getenv('computername');

settings.filename = filename;
settings.filenameNoExt = settings.filename(1:end-4); % strip '.bdf'

% filename(1:end-6) strips the '_x.bdf' suffix, leaving the pXXXX folder name
if strcmp(compname,'DESKTOP-BHR0AU7')
    settings.filedirRaw = strcat('E:\ANL\Experiments\DATA\VAST\RAW\',filename(1:end-6),'\');
    settings.filedirAnalyzed = strcat('E:\ANL\Experiments\DATA\VAST\PREPROCESSED\ERP\',...
        settings.filenameNoExt, '\');
    tfrdir = strcat('E:\ANL\Experiments\DATA\VAST\PREPROCESSED\TFR\',...
        settings.filenameNoExt, '\');
elseif strcmp(compname,'MSI')
    % Now using external SSD for laptop analysis -- should be drive E:
    settings.filedirRaw = strcat('E:\ANL\RAW_DATA\VAST\',filename(1:end-6),'\');
    settings.filedirAnalyzed = strcat('E:\ANL\PREPROCESSED_DATA\VAST\ERP\',...
        settings.filenameNoExt, '\');
    % NOTE(review): this points at the ERP folder while the other machines
    % use a TFR folder for tfrdir -- looks like a copy/paste slip; confirm.
    % (tfrdir is not used anywhere in the visible pipeline.)
    tfrdir = strcat('E:\ANL\PREPROCESSED_DATA\VAST\ERP\',...
        settings.filenameNoExt, '\');
elseif strcmp(compname,'RKC-PER-WD-0002')
    settings.filedirRaw = strcat('C:\Users\mnjoroge\Documents\VAST\VAST_Raw_Data\',filename(1:end-6),'\');
    settings.filedirAnalyzed = strcat('C:\Users\mnjoroge\Documents\VAST\VAST_Preprocessed_Data\ERP\',...
        settings.filenameNoExt, '\');
    tfrdir = strcat('C:\Users\mnjoroge\Documents\VAST\VAST_Preprocessed_Data\TFR\',...
        settings.filenameNoExt, '\');
else
    % Unknown machine: stop here rather than guessing at data locations
    error('Paths have not yet been set for this machine')
end
    
if ~exist(settings.filedirAnalyzed,'dir') % does the Analysis folder not exist?
    mkdir(settings.filedirAnalyzed) % create it if not
end

%% Load any previous data and accessory variables
% Called without outputs; presumably it places the needed variables into
% this workspace (e.g. via assignin) -- verify against the loader itself.
if STARTSTEP > 0
    load_analyzed_vars_VAST_ERP(STARTSTEP,settings.filedirAnalyzed,settings.filenameNoExt);
end

% Finish path setup...
settings.wholefile = strcat(settings.filedirRaw,settings.filename);
save([settings.filedirAnalyzed settings.filenameNoExt '_settings'],...
    'settings','-v7.3');
% Add the current directory containing this function to the path
addpath(genpath(pwd));
% Add FieldTrip toolbox to the path as well
if strcmp(compname,'DESKTOP-BHR0AU7')
    addpath('C:\Users\jtjus\Documents\MATLAB\fieldtrip-20180512\fieldtrip-20180512');
    addpath(genpath('E:\MATLAB\CSDtoolbox'));
elseif strcmp(compname,'MSI')
    addpath('D:\MATLAB\fieldtrip-20180809\fieldtrip-20180809');
    addpath(genpath('D:\MATLAB\CSDtoolbox'));
elseif strcmp(compname,'RKC-PER-WD-0002')
    addpath('C:\Users\mnjoroge\Documents\MATLAB\fieldtrip-20180512\fieldtrip-20180512')
    addpath(genpath('C:\Users\mnjoroge\Documents\MATLAB\CSDtoolbox'))
else
    error('Set FieldTrip path -- then delete this error.')
end
ft_defaults % initialize minimum required FT paths


% Get Biosemi 64 channel locations
cfg = [];
cfg.layout = 'biosemi64.lay'; % layout file located in FieldTrip folder > template > layout
q = ft_prepare_layout(cfg);


%% STEP 0: read in continuous raw data, immediately downsample
% Data will be read in one channel at a time, immediately downsampled for
% feasibility

if STARTSTEP == 0
    
    % Settings for file read
    cfg = [];
    cfg.dataset = settings.wholefile;
    cfg.reref = 'no';
    cfg.continuous = 'yes';
    % Settings for ds
    cfg_ds = [];
    cfg_ds.resamplefs = settings.ds2;

    % Due to dataset size, read in one channel at a time...
    % Group 1) datafiles with ALL channels accidentally recorded
    if strcmp(filename, 'p0246_v.bdf')
        chans2read = [65:128, 257:264, 272];
        group = 1;
    % Group 2) datafiles with only A+B recorded
    elseif ismember(filename, {'p2317_a.bdf', 'p2317_v.bdf'})
        chans2read = 1:73;
        group = 2;
    % Group 3) datafiles with A:D recorded, due to broken A+V inputs
    else
        group = 3;
        chans2read = 65:137;
    end
    
    data_collected = cell(length(chans2read), 1);
    for ch = 1:length(chans2read)
        cfg.channel = chans2read(ch);
        temp = ft_preprocessing(cfg);
        % and immediately downsample
        data_collected{ch} = ft_resampledata(cfg_ds, temp);
    end

    % Properly concatenate all the individual channel data
    cfg = [];
    rawData = ft_appenddata(cfg, data_collected{:});

    % Clear unnecessary variables
    clear temp data_collected

    % Now can re-reference the data
    cfg = [];
    cfg.reref = 'yes';
    cfg.continuous = 'yes';
    cfg.refchannel = {'EXG1', 'EXG2'};
    rawData = ft_preprocessing(cfg, rawData);

    % Get the channel labels right
    % For groups 1 and 3 the scalp data were read from the C+D banks, so
    % exactly one label containing 'C32' and one containing 'D32' must exist
    if group ~= 2
        if ~(sum(contains(rawData.label, 'C32'))==1 &&...
                sum(contains(rawData.label, 'D32'))==1)
            error("chans2read set up for broken A+B inputs, but data doesn't match this.")
        end
    end
    rawData.label(1:64) = q.label(1:64); % use actual labels instead of C+D

    % make the eye movement data bipolar
    rawData.trial{1}(67,:) = rawData.trial{1}(67,:) - rawData.trial{1}(68,:);
    rawData.trial{1}(69,:) = rawData.trial{1}(69,:) - rawData.trial{1}(70,:);

    % Remove mastoid channels, unused EXG channels
    % Note, hard-coded for our 64 channel setup
    cfg = [];
    cfg.channel = rawData.label([1:64,67,69]);

    rawData = ft_preprocessing(cfg, rawData);

    % Message now names the file that is actually written ('_0_ds')
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_0_ds'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_0_ds'],'rawData','-v7.3');
    
end % STEP 0


%% STEP 1: Filter data for ERP analysis (between 0.5 and 20 Hz)
% Will be accomplished by separate high- and low-pass filters, so
% transition widths can be set separately

if STARTSTEP <= 1

    cfg = [];
    cfg.continuous = 'yes';
    % High-pass filter
    cfg.hpfilter = 'yes';
    cfg.hpfilttype = 'firws';
    cfg.hpfreq = 0.5; % Hz
    cfg.hpfiltdf = 0.5*cfg.hpfreq;
    cfg.hpfiltwintype = 'kaiser';
    % Low-pass filter
    cfg.lpfilter = 'yes';
    cfg.lpfilttype = 'firws';
    cfg.lpfreq = 20; % Hz
    cfg.lpfiltdf = 0.15*cfg.lpfreq;
    cfg.lpfiltwintype = 'kaiser';

    rawData = ft_preprocessing(cfg,rawData);

    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_1_ds_filt'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_1_ds_filt'], 'rawData','-v7.3');

end % Step 1


%% Step 2: Detect artifact segments for ICA

if STARTSTEP <= 2 && loadMusc == 1
    % ----------------------------------------------
    %%% WORK NEEDED HERE
    % ----------------------------------------------
    error('not set up yet!')
    
% % %     % Try loading the muscle artifacts from the ERP directory  
% % %     try
% % %         load(strcat(erpdir,settings.filenameNoExt,'_STEP3_vars.mat'));
% % %         disp('Visually-definied artifacts loaded from ERP analysis.')
% % %         
% % %         % Save here to facilitate loading later
% % %         save([settings.filedirAnalyzed settings.filenameNoExt '_STEP3_vars.mat'],'badChans','muscArtfct');
% % %     catch
% % %         disp('Loading of visually-defined artifacts FAILED.')
% % %         loadMusc = 0;
% % %     end
end
    
if STARTSTEP <= 2 && loadMusc == 0 % Identify the muscle artifacts by hand
    disp('Mark muscle artifacts and take note of any bad channel labels.')
    input('Press ENTER to continue')
    
    cfg = [];
    cfg.viewmode = 'vertical';
    cfg.blocksize = 30;
    cfg.ylim = [-30 30];
   
    cfg = ft_databrowser(cfg, rawData);

    % Save a note of bad channels
    temp = inputdlg('Enter SPACE-SEPARATED bad channel LABELS:',...
        'Any bad channels to interpolate?', [1 50]);
    temp = char(temp);
    badChans = strsplit(temp,' ');
    clear temp

    % SAVE 
    muscArtfct = cfg.artfctdef.visual.artifact;
    save([settings.filedirAnalyzed settings.filenameNoExt '_STEP2_vars.mat'],...
         'badChans','muscArtfct');
 
end % Step 2


%% STEP 3: ICA to remove blinks/saccades

if STARTSTEP <= 3
    
    % TEMPORARILY reject the identified muscle artifacts
    % Note: This is done for ICA cleanliness. Good ICA components will be 
    % back-projected onto the previous version of rawData.
    cfg = [];
    cfg.artfctdef.reject = 'partial';
    cfg.artfctdef.muscle.artifact = muscArtfct;
    rawData = ft_rejectartifact(cfg, rawData);
    
    % Select good channels and run the ICA
    cfg = [];
    allChans = rawData.label;
    isChanOk = ~ismember(allChans, badChans); % true for every non-bad label
    cfg.channel = rawData.label(isChanOk);
    cfg.method = 'runica';

    components = ft_componentanalysis(cfg,rawData);
  
    % get rid of the dummy data -- will reload the unaltered data afterward
    clear rawData
    
    % Fig 1: identified components 1:32
    cfg = [];
    cfg.component = 1:32;
    cfg.layout = 'biosemi64.lay';
    cfg.comment = 'no';
    figure();
    ft_topoplotIC(cfg,components)
    % Fig 2: identified components 33:64
    cfg.component = 33:max(size(components.topo));
    figure();
    ft_topoplotIC(cfg,components)
    
    % Pause here to let the user check out the component topographies
    input('Press ENTER to continue to Trial View')
    
    close all
    
    % Visualize components as time series
    cfg = [];
    cfg.layout = 'biosemi64.lay';
    cfg.viewmode = 'component';

    ft_databrowser(cfg, components)
    
    % Pause here to let the user check out the component time series
    input('Press ENTER to continue, specify blink components.')
    
    % Gather the identified components
    temp = inputdlg('Enter space-separated component numbers:',...
                 'Identify Blink Components', [1 50]);
    blinkCmp = str2num(temp{:});
    clear temp

    % Remove the bad components and backproject onto the unaltered data
    load([settings.filedirAnalyzed settings.filenameNoExt '_1_ds_filt.mat']);

    cfg = [];
    cfg.component = blinkCmp;
    rawData = ft_rejectcomponent(cfg, components, rawData);
    
    % ------------
    % Do some additional cleanup at this stage
    % ------------
    
    % Remove the EXG channels -- they won't be used further
    cfg = [];
    cfg.channel = rawData.label(~strcmp(rawData.label,'EXG3') & ~strcmp(rawData.label,'EXG5'));
    rawData = ft_preprocessing(cfg,rawData);
    
    % Prepare the neighborhood structure for interpolating any bad channels
    cfg = [];
    cfg.method = 'distance';
    cfg.layout = 'biosemi64.lay';   
    neighbours = ft_prepare_neighbours(cfg,rawData);
    
    % Interpolate bad channels
    cfg = [];
    cfg.method = 'spline';
    cfg.badchannel = badChans';
    cfg.neighbours = neighbours;
    cfg.layout = 'biosemi64.lay';
    % NOTE: apparently best practice is to load individualized electrode
    % positions from the polhemus data. Incorporate this down the line. See
    % documentation for ft_channelrepair and ft_read_sens
    rawData = ft_channelrepair(cfg,rawData);

    save([settings.filedirAnalyzed settings.filenameNoExt '_STEP3_vars'],...
        'components','neighbours','-v7.3');
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_3_ds_filt_ica'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_3_ds_filt_ica'],'rawData','-v7.3');
    
end % Step 3


%% Step 4: Pull relevant ERP triggers, save ERP definition cell array

if STARTSTEP <= 4
    
    % Read events from bdf file, format them
    eventTrigs = ft_read_event(settings.wholefile);
    eVals = extractfield(eventTrigs,'value');
    eSamps = extractfield(eventTrigs,'sample');
    eTypes = extractfield(eventTrigs,'type');
    realEvents = strcmp(eTypes,'STATUS');
    % Keep only the STATUS samples when their count matches the number of
    % extracted values, so eSamps and eVals line up one-to-one
    if sum(realEvents) == length(eVals) % under normal conditions should always be true
        eSamps = eSamps(realEvents);
    end
    
    % Downsample triggers
    eSamps = round((settings.ds2 * eSamps) ./ settings.origFs);
    
    % Reconstruct any missing trigger values from the behavioral cfg
    % (triggered when the ENC and RET trigger codes do not all occur
    % equally often)
    if ~isequal(sum(eVals==65290),sum(eVals==65291),sum(eVals==65292),...
            sum(eVals==65293),sum(eVals==65300),sum(eVals==65301),...
            sum(eVals==65302),sum(eVals==65303))

        warning('Data is missing some trial triggers... reconstructing them from existing triggers.')
        [eVals,eSamps] = repair_broken_trigs(eSamps,eVals,settings);
    end
    
    % Set all relevant trigger values for ERP analysis
    trigs.enc.s1 = eSamps(eVals == 65290);
    trigs.enc.s2 = eSamps(eVals == 65291);
    trigs.enc.s3 = eSamps(eVals == 65292);
    trigs.enc.s4 = eSamps(eVals == 65293);
    trigs.int.s1 = eSamps(eVals == 65310);
    trigs.int.s2 = eSamps(eVals == 65311);
    trigs.int.s3 = eSamps(eVals == 65312);
    trigs.ret.s1 = eSamps(eVals == 65300);
    trigs.ret.s2 = eSamps(eVals == 65301);
    trigs.ret.s3 = eSamps(eVals == 65302);
    trigs.ret.s4 = eSamps(eVals == 65303);

    % special correction for one dataset with extra trials
    if strcmp(settings.filenameNoExt, 'p6554_a')
        trigs.enc.s1 = trigs.enc.s1(21:end);
        trigs.enc.s2 = trigs.enc.s2(21:end);
        trigs.enc.s3 = trigs.enc.s3(21:end);
        trigs.enc.s4 = trigs.enc.s4(21:end);
        trigs.int.s1 = trigs.int.s1(21:end);
        trigs.int.s2 = trigs.int.s2(21:end);
        trigs.int.s3 = trigs.int.s3(21:end);
        trigs.ret.s1 = trigs.ret.s1(21:end);
        trigs.ret.s2 = trigs.ret.s2(21:end);
        trigs.ret.s3 = trigs.ret.s3(21:end);
        trigs.ret.s4 = trigs.ret.s4(21:end);
    end

    % Define the trial structure; get definition of ERP positions in trial
    [trl, ERPtype, trialAssign] = ft_trialfun_VAST_ERP(trigs, settings);
    
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_STEP4_vars'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_STEP4_vars'],...
        'eventTrigs', 'trigs', 'trl', 'ERPtype', 'trialAssign', '-v7.3');

end % Step 4


%% Step 5: Epoch around ERP event triggers

if STARTSTEP <= 5
    
    % epoch based on trl
    cfg = [];
    cfg.trl = trl;
    rawData = ft_redefinetrial(cfg, rawData);
    
    % if epoching at the window level, replace ERP baselines with first 
    % stimuli in each window
    % ('no' leaves the data untouched but still creates cantReconstruct)
    if strcmp(settings.BLmethod, 'window')
        [rawData, cantReconstruct] =...
            replace_baselines('yes', rawData, ERPtype, trialAssign, settings);
    elseif strcmp(settings.BLmethod, 'epoch')
        [rawData, cantReconstruct] =...
            replace_baselines('no', rawData, ERPtype, trialAssign, settings);
    end
 
    % The number of epochs prior to peak-to-peak artifact rejection
    origNtrl = length(rawData.trial);

    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_5_ds_filt_ica_ep'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_STEP5_vars'],...
        'origNtrl', 'cantReconstruct', '-v7.3');
    save([settings.filedirAnalyzed settings.filenameNoExt '_5_ds_filt_ica_ep'],'rawData','-v7.3');

end % Step 5


%% Step 6: Peak-to-peak artifact rejection

if STARTSTEP <= 6
    
    badEpochs = zeros(size(rawData.trial));
    % if baseline couldn't be found, whole window is bad
    for t = 1:trialAssign(end)
        if cantReconstruct.enc(t) == 1
            badEpochs(trialAssign == t & ismember(ERPtype,...
                {'enc_s1','enc_s2','enc_s3','enc_s4'})) = 1;
        end
        if cantReconstruct.int(t) == 1
            badEpochs(trialAssign == t & ismember(ERPtype,...
                {'int_s1','int_s2','int_s3'})) = 1;
        end
        if cantReconstruct.ret(t) == 1
            badEpochs(trialAssign == t & ismember(ERPtype,...
                {'ret_s1','ret_s2','ret_s3','ret_s4'})) = 1;
        end
    end
    % bad epoch if peak-to-peak threshold is exceeded
    for rTr = 1:length(badEpochs)
        currdata = rawData.trial{rTr};
        % get minima and maxima in each channel
        minVals = min(currdata, [], 2);
        maxVals = max(currdata, [], 2);
        p2p = maxVals - minVals;
        % mark as bad if threshold is exceeded
        if any(p2p >= settings.p2pThresh)
            badEpochs(rTr) = 1;
        end
    end
    badEpochs = logical(badEpochs);
    
    fprintf('Automatic Rejection Rate: %0.02f \n', sum(badEpochs)/length(badEpochs))
    
    % Rid yourself of the bad epochs
    cfg = [];
    goodTrials = ~ismember(1:length(badEpochs),find(badEpochs));
    cfg.trials = goodTrials;
    rawData = ft_selectdata(cfg, rawData);

    % Clean the trial definition...
    trl_cln = trl(~badEpochs,:);
    % ... and the other reference variables
    ERPtype_cln = ERPtype(~badEpochs);
    trialAssign_cln = trialAssign(~badEpochs);
    clear ERPtype trialAssign
    
    % SAVE
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_6_ds_filt_ica_ep_clean'])
    
    save([settings.filedirAnalyzed settings.filenameNoExt '_STEP6_vars'],...
        'badEpochs', 'trl_cln', 'ERPtype_cln', 'trialAssign_cln', '-v7.3');
    save([settings.filedirAnalyzed settings.filenameNoExt '_6_ds_filt_ica_ep_clean'],'rawData','-v7.3');
    
end


%% Step 7: Epoch-level baseline correction

if STARTSTEP <= 7
    
    % First, identify the baseline window (consistent by trl definition)
    % blend = index of the sample whose time stamp is closest to baseDur;
    % the baseline is every sample from the epoch start up to that one
    [~,blend] = min(abs(settings.baseDur - rawData.time{1}));
    baseline = 1:blend;

    for epoch = 1:length(rawData.trial)
        % pull the signal for each epoch
        tempdata = rawData.trial{epoch};   
        % get the average baseline at each channel
        blfactors = mean(tempdata(:,baseline),2);
        tempdata = tempdata - blfactors; % subtract off each channel's baseline
        % replace the baseline corrected data
        rawData.trial{epoch} = tempdata;
    end
    
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_7_ds_filt_ica_ep_clean_demean'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_7_ds_filt_ica_ep_clean_demean'],'rawData','-v7.3');
    
end


%% Step 8: Calculate grand average ERPs

if STARTSTEP <= 8
    
    trlConds = get_trial_conditions(settings);
    
    % Set up and preallocate the final output
    % avgERPs / avgERPs_first3 layout: row 1 = condition label, row 2 = a
    % 3-row cell per condition (row 1 ERP type, row 2 count, row 3 average)
    uniqConds = unique(trlConds);
    uniqTypes = unique(ERPtype_cln);
    uniqTypes_noInt = uniqTypes(~contains(uniqTypes, 'int'));
    [avgERPs, avgERPs_first3] = deal([uniqConds; cell(size(uniqConds))]);
    % size(...,2), not length(): avgERPs has 2 rows, so length() would
    % overrun the columns when there is only one condition
    for i = 1:size(avgERPs, 2)
        if contains(uniqConds{i}, 'none')
            avgERPs{2,i} = [uniqTypes_noInt; cell(2, length(uniqTypes_noInt))];
        else
            avgERPs{2,i} = [uniqTypes; cell(2, length(uniqTypes))];
        end
        avgERPs_first3{2,i} = [{'enc','ret'}; cell(2,2)];
    end
    
    cfg = [];
    for c = 1:length(uniqConds)
        trlCorrectCond = find(strcmp(trlConds, uniqConds{c}));
        % --- Each position separately ---
        % size(...,2), not length(): the inner cell has 3 rows, so length()
        % would overrun the columns when fewer than 3 ERP types are present
        for t = 1:size(avgERPs{2,c}, 2)
            % check for ERP type and condition match
            ERPfind = strcmp(ERPtype_cln, avgERPs{2,c}{1,t}) &...
                      ismember(trialAssign_cln, trlCorrectCond);
            ERPfind = check_for_trl_nans(ERPfind, rawData);
            cfg.trials = ERPfind;
            avgERPs{2,c}{2,t} = sum(ERPfind); % ERP counts
            avgERPs{2,c}{3,t} = ft_timelockanalysis(cfg, rawData);
            % correct timebases while we're here
            avgERPs{2,c}{3,t}.time = avgERPs{2,c}{3,t}.time - settings.baseDur;
        end

        % --- Average ERPs in positions 1-3 together ---
        for t = 1:size(avgERPs_first3{2,c}, 2)
            if t == 1 % ENC ERPs
                currTypes = {'enc_s1','enc_s2','enc_s3'};
            elseif t == 2 % RET ERPs
                currTypes = {'ret_s1','ret_s2','ret_s3'};
            end
            ERPfind = ismember(ERPtype_cln, currTypes) &...
                      ismember(trialAssign_cln, trlCorrectCond);
            % same NaN guard as the per-position averages above, so an
            % epoch containing NaNs cannot contaminate the pooled average
            ERPfind = check_for_trl_nans(ERPfind, rawData);
            cfg.trials = ERPfind;
            avgERPs_first3{2,c}{2,t} = sum(ERPfind); % ERP count
            avgERPs_first3{2,c}{3,t} = ft_timelockanalysis(cfg, rawData);
            % correct timebases while we're here
            avgERPs_first3{2,c}{3,t}.time =...
                avgERPs_first3{2,c}{3,t}.time - settings.baseDur;
        end
    end
    
    % Only save the actual data
    for i = 1:size(avgERPs,2)
        for j1 = 1:size(avgERPs{2,i},2)
            avgERPs{2,i}{3,j1} = avgERPs{2,i}{3,j1}.avg;
        end
        for j2 = 1:size(avgERPs_first3{2,i},2)
            avgERPs_first3{2,i}{3,j2} = avgERPs_first3{2,i}{3,j2}.avg;
        end
    end
    
    disp(['Saving: ',settings.filedirAnalyzed, settings.filenameNoExt, '_avgERPs'])
    save([settings.filedirAnalyzed settings.filenameNoExt '_avgERPs'],'avgERPs');
    save([settings.filedirAnalyzed settings.filenameNoExt '_avgERPs_first3'],...
        'avgERPs_first3');

end % Step 8

end % Preprocessing function


%% Helper function to get trial conditions from experiment cfg
function trlConds = get_trial_conditions(settings)
% GET_TRIAL_CONDITIONS  Build per-trial condition labels from the experiment cfg.
%
% Finds the single .mat file in settings.filedirRaw whose name contains
% settings.filenameNoExt, loads its "cfg" variable, and labels every trial
% of block b with the string [m1{b} d{b} '_' dist{b}].
%
% INPUT
%   settings : struct with fields
%                filedirRaw    - folder holding the behavioral .mat file
%                filenameNoExt - recording name without extension
% OUTPUT
%   trlConds : 1 x (nBlocks*nTrials) cell array of condition labels,
%              ordered block by block
%
% Errors when no matching .mat file exists, or when more than one does.

% List plain files only. Filtering on isdir (rather than dropping the first
% two dir() entries) does not rely on '.' and '..' being listed first.
listing = dir(settings.filedirRaw);
listing = listing(~[listing.isdir]);
allfiles = {listing.name};

% Candidate behavioral files: belong to this recording and end in .mat
isSubj = contains(allfiles, settings.filenameNoExt);
isMat = endsWith(allfiles, ".mat");
candidates = allfiles(isSubj & isMat);
if isempty(candidates)
    error("no experiment .mat file found for this recording")
elseif numel(candidates) > 1
    error("more than one experiment .mat file -- not sure which to load")
end
behavfile = candidates{1};

% Load only the experiment cfg from that file
expcfg = load(fullfile(settings.filedirRaw, behavfile), "cfg");
expcfg = expcfg.cfg;

% figure out condition assignment: every trial in a block shares one label
trlConds = cell(1, expcfg.nBlocks * expcfg.nTrials);
for b = 1:expcfg.nBlocks
    m = expcfg.m1{b};
    d = expcfg.d{b};
    distractor = expcfg.dist{b};
    currcond = strcat(m, d, '_', distractor);
    inds = ((b-1)*expcfg.nTrials) + 1 : b * expcfg.nTrials;
    trlConds(inds) = {currcond};
end

end


%% Helper function to replace baseline data with first stim in each window
function [rawData, cantReconstruct] =...
    replace_baselines(replace, rawData, ERPtype, trialAssign, settings)
% REPLACE_BASELINES  Copy the first stimulus' baseline onto later epochs.
%
% For every trial and every window (encoding, interfering task, retrieval)
% the samples of the window's first epoch (e.g. 'enc_s1') with time stamps
% below settings.baseDur are written over the leading samples of the
% remaining epochs of that window (s2, s3, ...).
%
% INPUTS
%   replace     : 'yes' to perform the replacement, 'no' to leave rawData
%                 untouched (cantReconstruct is still created, all zeros)
%   rawData     : struct with cell fields .trial and .time, one per epoch
%   ERPtype     : cell array with one type label per epoch
%   trialAssign : trial number of each epoch; its last element is taken as
%                 the number of trials
%   settings    : struct with field baseDur
% OUTPUTS
%   rawData         : data with baselines replaced (when replace is 'yes')
%   cantReconstruct : struct with fields enc/int/ret, each 1 x Ntrl, set to
%                     1 where the window's first epoch is missing

if ~ismember(replace, {'yes','no'})
    error("replace argument should be either 'yes' or 'no'")
end

Ntrl = trialAssign(end);
noFlags = zeros(1,Ntrl);
cantReconstruct.enc = noFlags;
cantReconstruct.int = noFlags;
cantReconstruct.ret = noFlags;

if ~strcmp(replace, 'yes')
    return
end

% One entry per window, processed in this order for every trial.
% 'optional' windows are skipped silently when the trial has none of their
% epochs (not every trial has an interfering task).
windows = struct(...
    'field',    {'enc', 'int', 'ret'},...
    'types',    {{'enc_s1','enc_s2','enc_s3','enc_s4'},...
                 {'int_s1','int_s2','int_s3'},...
                 {'ret_s1','ret_s2','ret_s3','ret_s4'}},...
    'optional', {false, true, false},...
    'msg',      {'no suitable baseline for trial %d, encoding window',...
                 'no suitable baseline for trial %d, interfering task window',...
                 'no suitable baseline for trial %d, retrieval window'});

for t = 1:Ntrl
    trlEpochs = find(trialAssign == t);
    trlTypes = ERPtype(trlEpochs);

    for w = 1:numel(windows)
        win = windows(w);

        if win.optional && ~any(ismember(trlTypes, win.types))
            continue % window absent from this trial
        end

        % The window's first epoch supplies the baseline
        isFirst = contains(trlTypes, win.types{1});
        if ~any(isFirst)
            warning(win.msg, t)
            cantReconstruct.(win.field)(t) = 1;
            continue
        end
        srcEpoch = trlEpochs(isFirst);
        firstBL = rawData.trial{srcEpoch};
        firstBL = firstBL(:, rawData.time{srcEpoch} < settings.baseDur);
        nBL = size(firstBL,2);

        % Overwrite the leading samples of every later epoch in the window
        for s = 2:numel(win.types)
            isTarget = contains(trlTypes, win.types{s});
            if any(isTarget)
                dstEpoch = trlEpochs(isTarget);
                patched = rawData.trial{dstEpoch};
                patched(:, 1:nBL) = firstBL;
                rawData.trial{dstEpoch} = patched;
            end
        end
    end % window loop
end % trial loop

end


%% Helper function to prevent averaging of data with NaNs

function ERPfind = check_for_trl_nans(ERPfind, rawData)
% CHECK_FOR_TRL_NANS  Deselect epochs whose data contain any NaN.
%
% INPUTS
%   ERPfind : selection mask over the epochs in rawData.trial
%   rawData : struct whose .trial cell holds one data matrix per epoch
% OUTPUT
%   ERPfind : the same mask, with selected epochs that contain at least one
%             NaN switched to false; unselected entries are left alone

selected = find(ERPfind);

% Test only the selected epochs for NaNs
hasNan = cellfun(@(epochData) any(isnan(epochData(:))), rawData.trial(selected));

ERPfind(selected(logical(hasNan))) = false;

end
